# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# Dummy shapes that are identical in every entry that carries the input.
_TEMB_SHAPE = (2, 1280)
_ENCODER_HIDDEN_STATES_SHAPE = (2, 77, 2048)


def _sdxl_block(hidden_states, res_inputs=(), res_outputs=0, with_encoder_states=True):
    """Build one entry of ``SDXL_ONNX_CONFIG``.

    Args:
        hidden_states: Dummy shape stored under ``"hidden_states"``.
        res_inputs: Dummy shapes stored, in order, under
            ``"res_hidden_states_0"``, ``"res_hidden_states_1"``, ...
        res_outputs: How many ``"res_samples_<i>"`` names follow ``"sample"``
            in ``"output_names"``.
        with_encoder_states: Whether the entry lists ``"encoder_hidden_states"``
            as an input.

    Returns:
        A dict with the keys ``"dummy_input"``, ``"output_names"`` and
        ``"dynamic_axes"``. Every call returns freshly created dicts/lists.
    """
    res_names = [f"res_hidden_states_{idx}" for idx in range(len(res_inputs))]

    # Key order in ``dummy_input``: hidden_states, residual inputs, temb,
    # encoder_hidden_states. Insertion order is kept deliberately in case a
    # consumer relies on it.
    dummy_input = {"hidden_states": hidden_states}
    dummy_input.update(zip(res_names, res_inputs))
    dummy_input["temb"] = _TEMB_SHAPE
    if with_encoder_states:
        dummy_input["encoder_hidden_states"] = _ENCODER_HIDDEN_STATES_SHAPE

    # Key order in ``dynamic_axes`` differs from ``dummy_input``: the residual
    # inputs come last. Axis 0 is the only axis marked dynamic.
    dynamic_axes = {
        "hidden_states": {0: "batch_size"},
        "temb": {0: "steps"},
    }
    if with_encoder_states:
        dynamic_axes["encoder_hidden_states"] = {0: "batch_size"}
    for name in res_names:
        dynamic_axes[name] = {0: "batch_size"}

    output_names = ["sample"]
    output_names.extend(f"res_samples_{idx}" for idx in range(res_outputs))

    return {
        "dummy_input": dummy_input,
        "output_names": output_names,
        "dynamic_axes": dynamic_axes,
    }


# Per-block settings keyed by block name. Each value holds dummy input shapes,
# output names and dynamic-axis labels (presumably consumed by an ONNX export
# step -- confirm against the caller).
SDXL_ONNX_CONFIG = {
    "down_blocks.0": _sdxl_block(
        (2, 320, 128, 128),
        res_outputs=3,
        with_encoder_states=False,
    ),
    "down_blocks.1": _sdxl_block((2, 320, 64, 64), res_outputs=3),
    "down_blocks.2": _sdxl_block((2, 640, 32, 32), res_outputs=2),
    "mid_block": _sdxl_block((2, 1280, 32, 32)),
    "up_blocks.0": _sdxl_block(
        (2, 1280, 32, 32),
        res_inputs=[(2, 640, 32, 32), (2, 1280, 32, 32), (2, 1280, 32, 32)],
    ),
    "up_blocks.1": _sdxl_block(
        (2, 1280, 64, 64),
        res_inputs=[(2, 320, 64, 64), (2, 640, 64, 64), (2, 640, 64, 64)],
    ),
    "up_blocks.2": _sdxl_block(
        (2, 640, 128, 128),
        res_inputs=[(2, 320, 128, 128)] * 3,
        with_encoder_states=False,
    ),
}
